import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST dataset object from /tmp/data/. one_hot=False is requested;
# the test labels are later passed directly as scatter-plot colours.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)

# Hyperparameters
batch_size = 32          # examples per training step
save_path = 'model'      # directory holding the checkpoints
max_train_epoch = 5      # training stops once this many epochs are done
lr = 0.001               # learning rate handed to the Adam optimizer
n_inputs = 784           # width of one input vector (reshaped to 28x28 for display)
classes = 10             # width of the `labels` placeholder
hidden1 = 256            # units in the first encoder layer
hidden2 = 2              # units in the bottleneck layer; plotted as a 2-D scatter
d_hidden1 = 256          # units in the first decoder layer
# The decoder output is subtracted from the input in the reconstruction loss,
# so its width is derived from n_inputs instead of repeating the literal 784.
d_hidden2 = n_inputs
example_to_show = 10     # number of test images fed through for the reconstruction plot


# network structure
class nn(object):
    """Dense autoencoder graph built on top of ``inputs``.

    Two sigmoid dense layers (``hidden1`` then ``hidden2`` units) encode the
    input, and two more (``d_hidden1`` then ``d_hidden2`` units) decode it.
    The loss is the mean squared difference between the decoder output and
    ``inputs``.

    Args:
        inputs: tensor fed to the first encoder layer; also used as the
            reconstruction target.
        name: outermost variable scope for everything created here.
        trainning: accepted but not used by this class.
        reuse: forwarded to the outermost ``tf.variable_scope``.

    Attributes set by the constructor:
        global_step: non-trainable int32 variable initialised to 0.
        y_pred: output of the last decoder layer.
        y_gt: the ``inputs`` tensor.
        encoder_vec: output of the last encoder layer.
        loss: scalar mean squared reconstruction error.
    """

    @staticmethod
    def _sigmoid_layer(tensor, units, scope, layer_name):
        """Add one sigmoid dense layer named ``layer_name`` inside ``scope``."""
        with tf.variable_scope(scope):
            return tf.layers.dense(tensor, units, activation=tf.nn.sigmoid, name=layer_name)

    def __init__(self, inputs, name='nn', trainning=True, reuse=False):
        with tf.variable_scope(name, reuse=reuse):
            self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32)

            with tf.variable_scope('encoder'):
                enc_first = self._sigmoid_layer(inputs, hidden1, 'hidden1', 'hidden1')
                enc_code = self._sigmoid_layer(enc_first, hidden2, 'hidden2', 'hidden2')

            with tf.variable_scope('decoder'):
                dec_first = self._sigmoid_layer(enc_code, d_hidden1, 'd_hidden1', 'd_hidden1')
                # NOTE(review): this layer is also named 'd_hidden1' (inside
                # scope 'd_hidden2'). It looks like a copy-paste slip, but
                # renaming it would presumably change the variable names stored
                # in existing checkpoints, so it is left as is.
                dec_out = self._sigmoid_layer(dec_first, d_hidden2, 'd_hidden2', 'd_hidden1')

            self.encoder_vec = enc_code
            self.y_gt = inputs
            self.y_pred = dec_out

            with tf.variable_scope('loss'):
                residual = self.y_pred - self.y_gt
                self.loss = tf.reduce_mean(tf.square(residual))

    def summary(self):
        """Register a scalar summary called 'loss' for ``self.loss``."""
        tf.summary.scalar(name='loss', tensor=self.loss)


# Placeholders. `labels` is declared here but is not fed by the session code
# visible in this file.
input_x = tf.placeholder(tf.float32, [None, n_inputs], name='inputs')
labels = tf.placeholder(tf.float32, [None, classes], name='labels')

# Model
train_model = nn(inputs=input_x)

# Optimizer: Adam on the reconstruction loss; the model's global_step is
# passed to minimize() so the optimizer advances it.
train_up = tf.train.AdamOptimizer(learning_rate=lr).minimize(
    loss=train_model.loss,
    global_step=train_model.global_step,
)

# Saver, created after the optimizer
saver = tf.train.Saver()

with tf.Session() as sess:
    # Register the loss summary, open the event writer (directory 'log'), and
    # merge every registered summary into one op.
    train_model.summary()
    train_writer = tf.summary.FileWriter('log', graph=sess.graph)
    merged = tf.summary.merge_all()

    # Restore the latest checkpoint under save_path if there is one, otherwise
    # initialise all variables. The checkpoint's file name is re-joined onto
    # save_path rather than using the recorded path as-is.
    ckpt = tf.train.get_checkpoint_state(save_path)
    if ckpt:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(save_path, ckpt_name))
    else:
        sess.run(tf.global_variables_initializer())

    # Create the checkpoint directory up front so the first saver.save() does
    # not depend on the library creating it.
    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    # Derive the starting epoch from the (possibly restored) global step.
    global_step_val = int(sess.run(train_model.global_step))
    tot_batch = mnist.train.num_examples // batch_size
    now_epoch = global_step_val // tot_batch

    # Training loop. The writer is closed in `finally` so buffered summary
    # events are written out even if training is interrupted.
    try:
        while now_epoch < max_train_epoch:
            print('epoch:', now_epoch)
            for _ in range(tot_batch):
                batch_images, _ = mnist.train.next_batch(batch_size)
                _, tmp_loss = sess.run([train_up, train_model.loss], feed_dict={input_x: batch_images})
                global_step_val += 1
                # NOTE: checkpoints and summaries are only written every 100
                # steps, so steps after the last multiple of 100 are not saved.
                if global_step_val % 100 == 0:
                    saver.save(sess, os.path.join(save_path, 'nn.ckpt'), global_step_val)
                    # The logged loss is evaluated on the first batch_size test images.
                    merged_summary = sess.run(merged, feed_dict={input_x: mnist.test.images[:batch_size]})
                    train_writer.add_summary(merged_summary, global_step_val)
            # Reported value is the loss of the epoch's final batch, not an average.
            epoch_loss = tmp_loss
            print('loss', epoch_loss)
            now_epoch += 1
    finally:
        train_writer.close()

    # Top row: original test images; bottom row: their reconstructions.
    # The grid is sized from example_to_show so it always matches the number
    # of reconstructions; squeeze=False keeps `axes` 2-D even for one column.
    result = sess.run(train_model.y_pred, feed_dict={input_x: mnist.test.images[:example_to_show]})
    fig, axes = plt.subplots(2, example_to_show, figsize=(example_to_show, 2), squeeze=False)
    for i in range(example_to_show):
        axes[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
        axes[1][i].imshow(np.reshape(result[i], (28, 28)))
    plt.show()

    # Scatter the two bottleneck activations of every test image, coloured by label.
    all_vec = sess.run(train_model.encoder_vec, feed_dict={input_x: mnist.test.images})
    plt.scatter(all_vec[:, 0], all_vec[:, 1], c=mnist.test.labels)
    plt.colorbar()
    plt.show()
